package com.we.wfc.common.utils;

import com.google.common.collect.Lists;
import com.we.wfc.common.annotation.*;
import com.we.wfc.common.comparator.TreeCompartor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;

import java.lang.reflect.Field;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * @Description: 树工具类
 * @Author:Liangzy(Feeling)
 * @Date:Create in 2019/12/30 2:25 下午
 */
@Slf4j
public class TreeUtils {

    /**
     * 使用某条件查询某个接口(不包含子列表)
     *
     * @param dirs
     *            tree列表
     * @param func
     *            条件函数
     * @return T
     */
    public static <T> T findNode(List<T> dirs, Function<T, T> func) {
        return findNode(dirs, null, true, func);
    }

    /**
     * 使用某条件查询某个接口(不包含子列表)(当T和E不同时，T的属性名要和E完全相同才会进行正确转换)
     *
     * @param dirs
     *            tree列表
     * @param func
     *            条件函数
     * @return T
     */
    public static <T, E> T findNode(List<T> dirs, Class<E> clazz, Function<T, T> func) {
        return findNode(dirs, clazz, true, func);
    }

    /**
     * 使用某条件查询某个接口
     *
     * @param dirs
     *            tree列表
     * @param excludeChild
     *            是否不需要子list
     * @param func
     *            条件函数
     * @return T
     */
    public static <T, E> T findNode(List<T> dirs, boolean excludeChild, Function<T, T> func) {
        return findNode(dirs, null, excludeChild, func);
    }

    /**
     * 使用某条件查询某个接口(当T和E不同时，T的属性名要和E完全相同才会进行正确转换)
     *
     * @param dirs
     *            tree列表
     * @param excludeChild
     *            是否不需要子list
     * @param func
     *            条件函数
     * @return T
     */
    @SuppressWarnings("unchecked")
    public static <T, E> T findNode(List<T> dirs, Class<E> clazz, boolean excludeChild, Function<T, T> func) {
        try {
            if (ConverterUtil.isEmpty(dirs)) {
                return null;
            }
            T findRet = null;
            Field childrenField = null;
            if (null == clazz) {
                // 查询标记了TreeChildren的属性
                List<Field> fieldArray = ConverterUtil.getAllFieldsByAnnotation(dirs.get(0).getClass(), TreeChildren.class);
                if (ConverterUtil.isNotEmpty(fieldArray)) {
                    childrenField = fieldArray.get(0);
                    childrenField.setAccessible(true);
                }
            } else {
                // 查询标记了TreeChildren的属性
                List<Field> fieldArray = ConverterUtil.getAllFieldsByAnnotation(clazz, TreeChildren.class);
                if (ConverterUtil.isNotEmpty(fieldArray)) {
                    String nodeAttrName = fieldArray.get(0).getName();
                    Field[] tarFields = ConverterUtil.getAllFields(dirs.get(0).getClass());
                    Map<String, Field> nodeFieldMap = Lists.newArrayList(tarFields).stream().collect(Collectors.toMap(Field::getName, v -> v, (f1, f2) -> f1));
                    childrenField = nodeFieldMap.get(nodeAttrName);
                    if (ConverterUtil.isNotEmpty(childrenField)) {
                        childrenField.setAccessible(true);
                    }
                }
            }
            for (T t : dirs) {
                T res = func.apply(t);
                if (ConverterUtil.isNotEmpty(res)) {
                    T ret = (T) res.getClass().newInstance();
                    String exclude = "";
                    if (excludeChild && ConverterUtil.isNotEmpty(childrenField)) {
                        exclude = childrenField.getName();
                    }
                    ConverterUtil.copyProperties(res, ret, exclude);
                    return ret;
                } else {
                    if (ConverterUtil.isNotEmpty(childrenField)) {
                        List<T> childList = (List<T>) childrenField.get(t);
                        if (ConverterUtil.isNotEmpty(childList)) {
                            findRet = findNode(childList, clazz, excludeChild, func);
                        }
                        if (ConverterUtil.isNotEmpty(findRet)) {
                            return findRet;
                        }
                    }
                }
            }
        } catch (Exception e) {
            return null;
        }
        return null;
    }

    /**
     * 按树等级查找节点
     *
     * @param dirs
     * @param level
     * @return
     */
    public static <T> List<T> getLevList(List<T> dirs, int level) {
        if (ConverterUtil.isNotEmpty(dirs)) {
            List<T> result = Lists.newArrayList();
            T node = dirs.get(0);
            Field anLevel = null;
            Field anChildren = null;
            Field[] fields = ConverterUtil.getAllFields(node.getClass());
            for (Field field : fields) {
                field.setAccessible(true);
                if (ConverterUtil.isNotEmpty(field.getAnnotation(TreeLevel.class))) {
                    anLevel = field;
                }
                if (ConverterUtil.isNotEmpty(field.getAnnotation(TreeChildren.class))) {
                    anChildren = field;
                }
            }
            if (ConverterUtil.isEmpty(anLevel)) {
                return dirs;
            }
            try {
                for (T t : dirs) {
                    int lev = ConverterUtil.toInteger(anLevel.get(t));
                    if (lev == level) {
                        result.add(t);
                    }
                    if (lev < level) {
                        List<T> chList = getLevChildren(t, level, anLevel, anChildren);
                        if (ConverterUtil.isNotEmpty(chList)) {
                            result.addAll(chList);
                        }
                    }
                }
                return result;
            } catch (Exception e) {
                return null;
            }
        }
        return null;
    }

    /**
     * 按等级查找子节点
     *
     * @param ent
     * @param level
     * @param anLevel
     * @param anChildren
     * @return
     */
    @SuppressWarnings("unchecked")
    private static <T> List<T> getLevChildren(T ent, int level, Field anLevel, Field anChildren) {
        try {
            List<T> result = Lists.newArrayList();
            if (ConverterUtil.isNotEmpty(ent)) {
                List<T> chList = (List<T>) anChildren.get(ent);
                if (ConverterUtil.isNotEmpty(chList)) {
                    for (T t : chList) {
                        int lev = ConverterUtil.toInteger(anLevel.get(t));
                        if (lev == level) {
                            result.add(t);
                        }
                        if (lev < level) {
                            result.addAll(getLevChildren(t, level, anLevel, anChildren));
                        }
                    }
                    return result;
                }
                return null;
            }
        } catch (Exception e) {
            return null;
        }
        return null;
    }

    /**
     * 树结构转list
     *
     * @param dirs
     * @return
     */
    public static <T> List<T> treeToList(List<T> dirs) {
        return treeToList(dirs, null);
    }

    /**
     * 树结构转list(当T和E不同时，T的属性名要和E完全相同才会进行正确转换 且 T 必须有一个public的无参数构造函数)
     *
     * @param dirs
     * @return
     */
    @SuppressWarnings("unchecked")
    public static <T, E> List<T> treeToList(List<T> dirs, Class<E> clazz) {
        if (ConverterUtil.isEmpty(dirs)) {
            return null;
        }
        Field anChildren = null;
        if (null == clazz) {
            Class<?> nodeClazz = dirs.get(0).getClass();
            List<Field> childrenList = ConverterUtil.getAllFieldsByAnnotation(nodeClazz, TreeChildren.class);
            if (ConverterUtil.isEmpty(childrenList)) {
                throw new RuntimeException(nodeClazz + " is not contains @TreeChildren field");
            }
            anChildren = childrenList.get(0);
            anChildren.setAccessible(true);
        } else {
            Class<?> nodeClazz = dirs.get(0).getClass();
            List<Field> childrenList = ConverterUtil.getAllFieldsByAnnotation(clazz, TreeChildren.class);
            if (ConverterUtil.isEmpty(childrenList)) {
                throw new RuntimeException(clazz + " is not contains @TreeChildren field");
            }
            Field[] nodeFields = ConverterUtil.getAllFields(nodeClazz);
            Map<String, Field> nodeFieldMap = Lists.newArrayList(nodeFields).stream().collect(Collectors.toMap(Field::getName, v -> v, (f1, f2) -> f1));
            String nodeAttrName = childrenList.get(0).getName();
            anChildren = nodeFieldMap.get(nodeAttrName);
            if (null == anChildren) {
                throw new RuntimeException(nodeClazz + " is not contains [" + nodeAttrName + "] field");
            }
            anChildren.setAccessible(true);
        }
        List<T> result = Lists.newArrayList();
        try {
            for (T t : dirs) {
                // 添加到返回list
                T newEntity = (T) dirs.get(0).getClass().newInstance();
                ConverterUtil.copyProperties(t, newEntity, anChildren.getName());
                result.add(newEntity);
                // 如果有子列表就递归
                List<T> childList = (List<T>) anChildren.get(t);
                if (ConverterUtil.isNotEmpty(childList)) {
                    List<T> childResult = treeToList(childList, clazz);
                    result.addAll(childResult);
                }
            }
        } catch (Exception e) {
            log.error("exception in treeToList", e);
        }
        return result;
    }

    /**
     * 将List转换为树结构
     *
     * @param dirs
     * @return
     */
    public static <T> List<T> listToTree(List<T> dirs) {
        return listToTree(dirs, null);
    }

    /**
     * 将List转换为树结构(当T和E不同时，T的属性名要和E完全相同才会进行正确转换)
     *
     * @param dirs
     * @return
     */
    @SuppressWarnings("unchecked")
    public static <T, E> List<T> listToTree(List<T> dirs, Class<E> clazz) {
        Field anId = null;
        Field anPid = null;
        // Field anName = null;
        Field anLevel = null;
        Field anSort = null;
        Field anParent = null;
        Field anChildren = null;
        Comparator<T> compartor = null;
        try {
            if (ConverterUtil.isNotEmpty(dirs)) {
                if (null == clazz) {
                    T node = dirs.get(0);
                    TreeComparator coustomComp = node.getClass().getAnnotation(TreeComparator.class);
                    if (null != coustomComp) {
                        compartor = coustomComp.value().newInstance();
                    }
                    Field[] fields = ConverterUtil.getAllFields(node.getClass());
                    for (Field field : fields) {
                        field.setAccessible(true);
                        if (ConverterUtil.isNotEmpty(field.getAnnotation(TreeId.class))) {
                            anId = field;
                        }
                        if (ConverterUtil.isNotEmpty(field.getAnnotation(TreePid.class))) {
                            anPid = field;
                        }
                        if (ConverterUtil.isNotEmpty(field.getAnnotation(TreeLevel.class))) {
                            anLevel = field;
                        }
                        if (ConverterUtil.isNotEmpty(field.getAnnotation(TreeSort.class))) {
                            anSort = field;
                        }
                        if (ConverterUtil.isNotEmpty(field.getAnnotation(TreeParent.class))) {
                            anParent = field;
                        }
                        if (ConverterUtil.isNotEmpty(field.getAnnotation(TreeChildren.class))) {
                            anChildren = field;
                        }
                    }
                } else {
                    T node = dirs.get(0);
                    Class<?> nodeClazz = node.getClass();
                    TreeComparator coustomComp = clazz.getAnnotation(TreeComparator.class);
                    if (null != coustomComp) {
                        compartor = coustomComp.value().newInstance();
                    }
                    Field[] fields = ConverterUtil.getAllFields(clazz);
                    Field[] nodeFields = ConverterUtil.getAllFields(nodeClazz);
                    Map<String, Field> nodeFieldMap = Lists.newArrayList(nodeFields).stream().collect(Collectors.toMap(Field::getName, v -> v, (f1, f2) -> f1));
                    for (Field field : fields) {

                        String fieldName = field.getName();
                        if (ConverterUtil.isNotEmpty(field.getAnnotation(TreeId.class))) {
                            anId = nodeFieldMap.get(fieldName);
                            if (null != anId) {
                                anId.setAccessible(true);
                            }
                        }
                        if (ConverterUtil.isNotEmpty(field.getAnnotation(TreePid.class))) {
                            anPid = nodeFieldMap.get(fieldName);
                            if (null != anPid) {
                                anPid.setAccessible(true);
                            }
                        }
                        if (ConverterUtil.isNotEmpty(field.getAnnotation(TreeLevel.class))) {
                            anLevel = nodeFieldMap.get(fieldName);
                            if (null != anLevel) {
                                anLevel.setAccessible(true);
                            }
                        }
                        if (ConverterUtil.isNotEmpty(field.getAnnotation(TreeSort.class))) {
                            anSort = nodeFieldMap.get(fieldName);
                            if (null != anSort) {
                                anSort.setAccessible(true);
                            }
                        }
                        if (ConverterUtil.isNotEmpty(field.getAnnotation(TreeParent.class))) {
                            anParent = nodeFieldMap.get(fieldName);
                            if (null != anParent) {
                                anParent.setAccessible(true);
                            }
                        }
                        if (ConverterUtil.isNotEmpty(field.getAnnotation(TreeChildren.class))) {
                            anChildren = nodeFieldMap.get(fieldName);
                            if (null != anChildren) {
                                anChildren.setAccessible(true);
                            }
                        }
                    }
                }
            } else {
                return dirs;
            }
            if (!ConverterUtil.isNotEmpty(anId, anPid, anParent, anChildren)) {
                throw new RuntimeException("The TreeId,TreePid,TreeParent,TreeChildren must be not null");
            }

            if (null == compartor && ConverterUtil.isNotEmpty(anSort)) {
                compartor = new TreeCompartor<T>(anSort);
            }

            // 首先找到根节点
            List<T> roots = findRoots(dirs, anId, anPid, anLevel, compartor);
            // 排除非根节点的全部元素
            List<T> notRoots = (List<T>) CollectionUtils.subtract(dirs, roots);
            if (ConverterUtil.isNotEmpty(roots)) {
                for (T root : roots) {
                    // 递归查找每个根节点的子节点列表
                    List<T> tmpChildren = findChildren(root, notRoots, anId, anPid, anLevel, anParent, anChildren, compartor);
                    // 找到了就说明有子节点
                    if (ConverterUtil.isNotEmpty(tmpChildren)) {
                        anChildren.set(root, tmpChildren);
                    }
                }
            }
            return roots;
        } catch (Exception e) {
            throw new RuntimeException("error in TreeUtils.listToTree", e);
        }
    }

    /**
     * 查询根节点
     *
     * @param allTreeNodes
     * @param anId
     * @param anPid
     * @param anLevel
     * @param compartor
     * @return
     */
    public static <T> List<T> findRoots(List<T> allTreeNodes, Field anId, Field anPid, Field anLevel, Comparator<T> compartor) {
        List<T> results = new ArrayList<T>();
        try {
            for (T node : allTreeNodes) {
                boolean isRoot = true;
                for (T comparedOne : allTreeNodes) {
                    Object nodePid = anPid.get(node);
                    if (ConverterUtil.isNotEmpty(nodePid) && nodePid.equals(anId.get(comparedOne))) {
                        isRoot = false;
                        break;
                    }
                }
                if (isRoot) {
                    if (ConverterUtil.isNotEmpty(anLevel)) {
                        anLevel.set(node, 0);
                    }
                    results.add(node);
                    // node.setRootId(node.getId());
                }
            }
        } catch (Exception e) {
            throw new RuntimeException("error in TreeUtils.findRoots", e);
        }
        if (ConverterUtil.isNotEmpty(compartor)) {
            Collections.sort(results, compartor);
        }
        return results;
    }

    /**
     * 查询子节点
     *
     * @param root
     * @param allTreeNodes
     * @param anId
     * @param anPid
     * @param anLevel
     * @param anParent
     * @param anChildren
     * @param compartor
     * @return
     */
    @SuppressWarnings("unchecked")
    public static <T> List<T> findChildren(T root, List<T> allTreeNodes, Field anId, Field anPid, Field anLevel, Field anParent, Field anChildren, Comparator<T> compartor) {
        List<T> children = new ArrayList<T>();
        try {
            for (T comparedOne : allTreeNodes) {
                Object onePid = anPid.get(comparedOne);
                if (ConverterUtil.isNotEmpty(onePid) && onePid.equals(anId.get(root))) {
                    anParent.set(comparedOne, root);
                    if (ConverterUtil.isNotEmpty(anLevel)) {
                        anLevel.set(comparedOne, ConverterUtil.toInteger(anLevel.get(root)) + 1);
                    }
                    children.add(comparedOne);
                }
            }
            List<T> notChildren = (List<T>) CollectionUtils.subtract(allTreeNodes, children);
            if (ConverterUtil.isNotEmpty(children)) {
                for (T child : children) {
                    List<T> tmpChildren = findChildren(child, notChildren, anId, anPid, anLevel, anParent, anChildren, compartor);
                    if (ConverterUtil.isNotEmpty(tmpChildren)) {
                        anChildren.set(child, tmpChildren);
                    }
                }
            }
        } catch (Exception e) {
            throw new RuntimeException("error in TreeUtils.findChildren", e);
        }
        if (ConverterUtil.isNotEmpty(compartor)) {
            Collections.sort(children, compartor);
        }
        return children;
    }
}
